# -*- coding: utf-8 -*-
"""
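Iteration messages, which store status information about the current engine
iteration, and the generators that yield them for training, validation and
inference.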
"""
import itertools
import time

from niftynet.engine.application_variables import CONSOLE, TF_SUMMARIES
from niftynet.engine.signal import TRAIN, VALID, INFER
from niftynet.utilities.util_common import look_up_operations

# Template for one console line: phase, iteration index, rendered console
# variables, and the iteration duration in seconds (3 decimal places).
CONSOLE_FORMAT = "{} iter {}, {} ({:.3f}s)"
# Phases accepted by the ``IterationMessage.phase`` setter.
SUPPORTED_PHASES = {TRAIN, VALID, INFER}


class IterationMessage(object):
    """
    Status of a single iteration of the engine loop: the iteration index and
    phase, the ops and feed dict to pass to ``session.run(...)``, the output
    received from it, and wall-clock timing of the iteration.

    It is generated by the application engine but can be modified by the
    application as well.
    """
    # Class-level defaults; the property setters store per-instance values.
    _current_iter = 0
    _current_iter_tic = 0
    _current_iter_toc = 0
    _current_iter_output = None

    _data_feed_dict = None
    _ops_to_run = None
    _phase = TRAIN

    _should_stop = None

    @property
    def current_iter(self):
        """
        Current iteration index; can be used to create complex schedules
        for the iterative training/validation/inference procedure.

        :return: integer of iteration
        """
        return self._current_iter

    @current_iter.setter
    def current_iter(self, value):
        # Assigning an index starts a new iteration: record its start time
        # and discard any output left over from a previous assignment.
        self._current_iter = int(value)
        self._current_iter_tic = time.time()
        self._current_iter_output = None

    @property
    def ops_to_run(self):
        """
        Operations (tf graph elements) to be fed into ``session.run(...)``.
        This is currently mainly used for passing network gradient update
        ops to ``session.run``.

        The returned dictionary is the stored object itself, not a copy.
        Callers may add entries to it in place, or replace it by assigning
        ``self.ops_to_run``.

        :return: the operation dictionary (created empty on first access)
        """
        if self._ops_to_run is None:
            self._ops_to_run = {}
        # NOTE(review): the setter does not type-check, so a non-dict value
        # is only caught here; ``assert`` is stripped under ``python -O``.
        assert isinstance(self._ops_to_run, dict), \
            'ops to run should be a dictionary'
        return self._ops_to_run

    @ops_to_run.setter
    def ops_to_run(self, value):
        self._ops_to_run = value

    @property
    def data_feed_dict(self):
        """
        A dictionary that maps graph elements to values, to be fed into
        ``session.run(...)`` as the feed_dict parameter.

        :return: the feed dictionary (created empty on first access)
        """
        if self._data_feed_dict is None:
            self._data_feed_dict = {}
        return self._data_feed_dict

    @data_feed_dict.setter
    def data_feed_dict(self, value):
        assert isinstance(value, dict), \
            'data_feed_dict should be a dictionary of placeholders:values'
        self._data_feed_dict = value

    @property
    def current_iter_output(self):
        """
        Graph output received by running ``session.run()`` for this
        iteration. Reset to ``None`` whenever ``current_iter`` is assigned.

        :return: the stored output, or ``None`` if it has not been set
        """
        return self._current_iter_output

    @current_iter_output.setter
    def current_iter_output(self, value):
        # Receiving the output marks the end of the iteration for timing.
        self._current_iter_output = value
        self._current_iter_toc = time.time()

    @property
    def should_stop(self):
        """
        The engine checks this property after each iteration.

        It can be modified by the application (e.g. in
        ``application.set_iteration_update()``) to create training
        schedules such as early stopping.

        :return: None or a handler that requested to stop the loop
        """
        return self._should_stop

    @should_stop.setter
    def should_stop(self, value):
        self._should_stop = value

    @property
    def phase(self):
        """
        A string indicating the phase: train, validation or inference.
        Assigned values are passed through ``look_up_operations`` against
        ``SUPPORTED_PHASES``.

        :return: the current phase (``TRAIN`` unless assigned)
        """
        return self._phase

    @phase.setter
    def phase(self, value):
        self._phase = look_up_operations(value, SUPPORTED_PHASES)

    @property
    def is_training(self):
        """
        :return: boolean value indicating if the phase is training
        """
        return self.phase == TRAIN

    @property
    def is_validation(self):
        """
        :return: boolean value indicating if the phase is validation
        """
        return self.phase == VALID

    @property
    def is_inference(self):
        """
        :return: boolean value indicating if the phase is inference
        """
        return self.phase == INFER

    @property
    def iter_duration(self):
        """
        Time measured from setting ``self.current_iter`` to setting
        ``self.current_iter_output``.

        :return: time duration of an iteration, in seconds; 0 when the
            output has not been set since ``current_iter`` was last assigned
        """
        # The toc may predate the latest tic (output not yet set for this
        # iteration); clamp so the duration never goes negative.
        current_toc = max(self._current_iter_toc, self._current_iter_tic)
        return current_toc - self._current_iter_tic

    def to_console_string(self):
        """
        Convert ``current_iter_output[CONSOLE]`` into a single console line
        showing phase, iteration index, console values and duration.
        Validation lines are indented by four spaces.

        :return: summary string
        """
        summary_indentation = "    " if self.is_validation else ""
        summary_format = summary_indentation + CONSOLE_FORMAT
        try:
            console_content = self.current_iter_output.get(CONSOLE, '')
        except AttributeError:
            # Output is not dict-like (e.g. still None): show a usage hint.
            console_content = "print to console -- set current_iter_output " \
                              "to a dictionary of {CONSOLE: 'content'}."
        result_str = _console_vars_to_str(console_content)
        summary = summary_format.format(
            self.phase, self.current_iter, result_str, self.iter_duration)
        return summary

    def to_tf_summary(self, writer=None):
        """
        Write ``current_iter_output[TF_SUMMARIES]`` to ``writer``, tagged
        with the current iteration index. Does nothing if ``writer`` is
        None, the output is not dict-like, or the summary is empty.

        :param writer: writer instance for summary output; must provide
            ``add_summary(summary, step)``
        :return: None
        """
        if writer is None:
            return
        try:
            summary = self.current_iter_output.get(TF_SUMMARIES, {})
        except AttributeError:
            summary = None

        if not summary:
            return
        writer.add_summary(summary, self.current_iter)


class IterationMessageGenerator(object):
    """
    Callable factory for iteration-message generators.

    Calling an instance returns a generator of IterationMessage instances.
    It yields an endless inference sequence when ``is_training_action`` is
    false, and interleaved training/validation iterations otherwise.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 initial_iter=0,
                 final_iter=0,
                 validation_every_n=0,
                 validation_max_iter=0,
                 is_training_action=True,
                 **_unused):
        # Start no earlier than -1, and never let the end precede the start.
        start_iter = max(initial_iter, -1)
        self.initial_iter = start_iter
        self.final_iter = max(final_iter, start_iter)
        self.validation_every_n = validation_every_n
        self.validation_max_iter = validation_max_iter
        self.is_training_action = is_training_action

    def __call__(self):
        if self.is_training_action:
            return _train_iter_generator(
                initial_iter=self.initial_iter,
                final_iter=self.final_iter,
                validation_every_n=self.validation_every_n,
                validation_max_iter=self.validation_max_iter)
        return _infer_iter_generator()


def _infer_iter_generator():
    """
    Yield an endless stream of inference-phase iteration messages,
    numbered 0, 1, 2, ...

    :return: iteration message instances
    """
    counter = itertools.count()
    for msg in _iter_msg_generator(counter, INFER):
        yield msg


def _train_iter_generator(initial_iter=0,
                          final_iter=0,
                          validation_every_n=0,
                          validation_max_iter=0):
    """
    Yield training iteration messages numbered ``initial_iter + 1`` up to
    and including ``final_iter``. After each training iteration whose
    number is a positive multiple of ``validation_every_n``, also yield
    ``validation_max_iter`` validation messages carrying that same number.

    :param initial_iter: training numbering starts at ``initial_iter + 1``
    :param final_iter: number of the last training iteration
    :param validation_every_n: validation period in training iterations;
        validation is skipped when this is not positive
    :param validation_max_iter: number of validation messages per round
    :return: iteration message instances
    """
    first, stop = initial_iter + 1, final_iter + 1
    for train_msg in _iter_msg_generator(range(first, stop), TRAIN):
        yield train_msg
        # Read after yielding, so an index changed by the consumer is used.
        iter_num = train_msg.current_iter
        if not (iter_num > 0 and validation_every_n > 0 and
                iter_num % validation_every_n == 0):
            continue
        # Validation messages reuse the training iteration number.
        repeated = [iter_num] * validation_max_iter
        for valid_msg in _iter_msg_generator(repeated, VALID):
            yield valid_msg


def _iter_msg_generator(count_generator, phase):
    """
    Wrap each iteration number from ``count_generator`` in a fresh
    IterationMessage tagged with ``phase``.

    :param count_generator: iterable yielding iteration numbers
    :param phase: one of TRAIN, VALID or INFER
    :return: generator of IterationMessage instances
    """
    for number in count_generator:
        msg = IterationMessage()
        msg.current_iter = number
        msg.phase = phase
        yield msg


def _console_vars_to_str(console_dict):
    """
    Printing values of variable evaluations to command line output.
    """
    if not console_dict:
        return ''
    if isinstance(console_dict, dict):
        console_str = ', '.join('{}={}'.format(key, val)
                                for (key, val) in console_dict.items())
    else:
        console_str = '{}'.format(console_dict)
    return console_str
